#include "MathUnitTests.h"

#include <cassert>

#include "math/MathConstants.h"
#include "math/Point2.h"
#include "math/Point3.h"
#include "math/Utils.h"
#include "math/Vector3.h"

namespace {
    // Every check below is an assert(), so all of these tests compile to
    // nothing when NDEBUG is defined. Build without NDEBUG to exercise them.

    // Tests for the free helper functions declared in math/Utils.h.
    void utilsTests() {
        // areEquals(): identical values compare equal.
        {
            const float num1 = 0.0f;
            const float num2 = 0.0f;
            const bool result = areEquals(num1, num2);
            assert(result == true);
        }

        // areEquals(): clearly different values do not compare equal.
        {
            const float num1 = 1.0f;
            const float num2 = 0.0f;
            const bool result = areEquals(num1, num2);
            assert(result == false);
        }

        // unitVectorFromOrientation(): orientation 0 maps to +Z, HALF_PI to
        // +X, -HALF_PI to -X, and both PI and -PI map to -Z.
        {
            Vector3 outVec;
            float orientation = 0.0f;
            Vector3 vec(0.0f, 0.0f, 1.0f);
            unitVectorFromOrientation(orientation, outVec);
            assert(outVec == vec);

            orientation = HALF_PI;
            vec.set(1.0f, 0.0f, 0.0f);
            unitVectorFromOrientation(orientation, outVec);
            assert(outVec == vec);

            orientation = -HALF_PI;
            vec.set(-1.0f, 0.0f, 0.0f);
            unitVectorFromOrientation(orientation, outVec);
            assert(outVec == vec);

            orientation = PI;
            vec.set(0.0f, 0.0f, -1.0f);
            unitVectorFromOrientation(orientation, outVec);
            assert(outVec == vec);

            orientation = -PI;
            vec.set(0.0f, 0.0f, -1.0f);
            unitVectorFromOrientation(orientation, outVec);
            assert(outVec == vec);
        }

        // orientationFromVector(): the inverse of the mapping above. Inputs
        // need not be unit length (lengths 1, 5 and 10 are used here).
        {
            const Vector3 vec1(0.0f, 0.0f, 10.0f);
            const float orientation1 = 0.0f;
            float orientation = orientationFromVector(vec1);
            assert(areEquals(orientation1, orientation) == true);

            const Vector3 vec2(1.0f, 0.0f, 0.0f);
            const float orientation2 = PI * 0.5f;
            orientation = orientationFromVector(vec2);
            assert(areEquals(orientation2, orientation) == true);

            const Vector3 vec3(-5.0f, 0.0f, 0.0f);
            const float orientation3 = PI * -0.5f;
            orientation = orientationFromVector(vec3);
            assert(areEquals(orientation3, orientation) == true);

            const Vector3 vec4(0.0f, 0.0f, -10.0f);
            const float orientation4 = PI;
            orientation = orientationFromVector(vec4);
            assert(areEquals(orientation4, orientation) == true);

            // Branch cut: the -Z direction (same vector as vec4) must come
            // back as +PI, never as -PI.
            const Vector3 vec5(0.0f, 0.0f, -10.0f);
            const float orientation5 = -PI;
            orientation = orientationFromVector(vec5);
            assert(areEquals(orientation5, orientation) == false);
        }

        // rotationAngle(): signed angle from src to dest. Swapping the
        // arguments flips the sign.
        {
            const Vector3 src1(5.0f, 0.0f, -5.0f);
            const Vector3 dest1(-5.0f, 0.0f, -5.0f);
            float angle = rotationAngle(src1, dest1);
            const float angleRes1 = HALF_PI;
            assert(areEquals(angle, angleRes1) == true);

            angle = rotationAngle(dest1, src1);
            const float angleRes1a = -HALF_PI;
            assert(areEquals(angle, angleRes1a) == true);

            // Degenerate zero-length inputs are expected to yield 0.
            const Vector3 src2(0.0f, 0.0f, 0.0f);
            const Vector3 dest2(0.0f, 0.0f, 0.0f);
            angle = rotationAngle(src2, dest2);
            const float angleRes2 = 0.0f;
            assert(areEquals(angle, angleRes2) == true);

            // Identical directions: no rotation.
            const Vector3 src3(1.0f, 0.0f, 0.0f);
            const Vector3 dest3(1.0f, 0.0f, 0.0f);
            angle = rotationAngle(src3, dest3);
            const float angleRes3 = 0.0f;
            assert(areEquals(angle, angleRes3) == true);

            // Opposite directions are expected to report -PI (not +PI).
            const Vector3 src4(1.0f, 0.0f, 0.0f);
            const Vector3 dest4(-1.0f, 0.0f, 0.0f);
            angle = rotationAngle(src4, dest4);
            const float angleRes4 = -PI;
            assert(areEquals(angle, angleRes4) == true);
        }
    }

    // Tests for the Point2 type.
    void point2Tests() {
        // Point2::Point2(x, y)
        {
            const float x = 0.0f;
            const float y = 5.0f;
            const Point2 p(x, y);

            assert(areEquals(p.x, x) == true);
            assert(areEquals(p.y, y) == true);
        }

        // Point2::operator==() and operator!=()
        {
            const float x1 = 10.0f;
            const float y1 = 0.0f;
            const Point2 p1(x1, y1);

            const float x2 = -10.0f;
            const float y2 = 0.0f;
            const Point2 p2(x2, y2);

            assert(p1 != p2);
            assert(!(p1 == p2));
        }

        // Point2::set(x, y)
        {
            const float x1 = 10.0f;
            const float y1 = -5.0f;
            Point2 p;
            p.set(x1, y1);

            assert(areEquals(x1, p.x) == true);
            assert(areEquals(y1, p.y) == true);
        }
    }

    // Tests for the Point3 type, including its interaction with Vector3.
    void point3Tests() {
        // Point3::Point3(x, y, z)
        {
            const float x = 0.0f;
            const float y = 5.0f;
            const float z = 10.0f;
            const Point3 p(x, y, z);

            assert(areEquals(p.x, x) == true);
            assert(areEquals(p.y, y) == true);
            assert(areEquals(p.z, z) == true);
        }

        // Point3::operator==() and operator!=()
        {
            const float x1 = 10.0f;
            const float y1 = 0.0f;
            const float z1 = -10.0f;
            const Point3 p1(x1, y1, z1);

            const float x2 = -10.0f;
            const float y2 = 0.0f;
            const float z2 = -10.0f;
            const Point3 p2(x2, y2, z2);

            assert(!(p1 == p2));
            assert(p1 != p2);
        }

        // Point3::operator+=(Vector3) and operator-=(Vector3): translating
        // by v and then by -v restores the original point.
        {
            Point3 p(0.0f, 0.0f, 0.0f);
            const Vector3 v(1.0f, 2.0f, 3.0f);

            p += v;
            assert(p == Point3(1.0f, 2.0f, 3.0f));

            p -= v;
            assert(p == Point3(0.0f, 0.0f, 0.0f));
        }

        // Point3::operator+(Vector3) and operator-(Vector3)
        {
            const Point3 p1(0.0f, 0.0f, 0.0f);
            const Vector3 v(1.0f, 2.0f, 3.0f);

            const Point3 p2 = p1 + v;
            assert(p2 == Point3(1.0f, 2.0f, 3.0f));

            const Point3 p3 = p1 - v;
            assert(p3 == Point3(-1.0f, -2.0f, -3.0f));
        }

        // Point3::set(x, y, z)
        {
            Point3 p;
            p.set(1.0f, 2.0f, 3.0f);
            assert(p == Point3(1.0f, 2.0f, 3.0f));
        }

        // Point3::isZero()
        {
            const float x1 = 0.0f;
            const float y1 = 0.0f;
            const float z1 = 0.0f;
            const Point3 p1(x1, y1, z1);

            const float x2 = 10.0f;
            const float y2 = 0.0f;
            const float z2 = -10.0f;
            const Point3 p2(x2, y2, z2);

            assert(p1.isZero() == true);
            assert(p2.isZero() == false);
        }
    }

    // Tests for the Vector3 type and the free vector functions.
    void vector3Tests() {
        // Vector3::Vector3(x, y, z)
        {
            const float x = 0.0f;
            const float y = 5.0f;
            const float z = 10.0f;
            const Vector3 vec(x, y, z);

            assert(areEquals(vec.x, x) == true);
            assert(areEquals(vec.y, y) == true);
            assert(areEquals(vec.z, z) == true);
        }

        // Vector3::Vector3(from, to): the vector is (to - from).
        {
            const Point3 from(-10.0f, 0.0f, 0.0f);
            const Point3 to(10.0f, 10.0f, 10.0f);
            const Vector3 v(from, to);
            assert(v == Vector3(20.0f, 10.0f, 10.0f));
        }

        // Vector3::operator==() and operator!=()
        {
            const Vector3 v1(10.0f, 0.0f, -1.0f);
            const Vector3 v2(0.0f, 1.0f, 2.0f);
            assert(v1 != v2);
            assert(!(v1 == v2));
        }

        // Vector3::operator+=() and operator-=()
        {
            Vector3 v1(1.0f, 1.0f, 1.0f);
            const Vector3 v2(2.0f, 0.0f, 2.0f);
            v1 += v2;
            assert(v1 == Vector3(3.0f, 1.0f, 3.0f));
            v1 -= v2;
            assert(v1 == Vector3(1.0f, 1.0f, 1.0f));
        }

        // Vector3::set(x, y, z)
        {
            Vector3 v;
            v.set(1.0f, 2.0f, 3.0f);
            assert(v == Vector3(1.0f, 2.0f, 3.0f));
        }

        // Vector3::set(from, to)
        {
            const Point3 from(1.0f, 2.0f, 3.0f);
            const Point3 to(2.0f, 3.0f, 4.0f);
            Vector3 v;
            v.set(from, to);
            assert(v == Vector3(1.0f, 1.0f, 1.0f));
        }

        // Vector3::operator*=(float): uniform scaling of every component.
        {
            const float x1 = -10.0f;
            const float y1 = 0.0f;
            const float z1 = 50.0f;
            Vector3 vec1(x1, y1, z1);

            const float x2 = -20.0f;
            const float y2 = 0.0f;
            const float z2 = 100.0f;
            const Vector3 vec(x2, y2, z2);

            const float s = 2.0f;

            vec1 *= s;

            assert(vec1 == vec);
        }

        // dotProduct()
        {
            const float x1 = 1.0f;
            const float y1 = 0.0f;
            const float z1 = 0.0f;
            const Vector3 vec1(x1, y1, z1);

            const float x2 = 0.0f;
            const float y2 = 0.0f;
            const float z2 = 1.0f;
            const Vector3 vec2(x2, y2, z2);

            const float x3 = -1.0f;
            const float y3 = 0.0f;
            const float z3 = 0.0f;
            const Vector3 vec3(x3, y3, z3);

            // Compare against the real expected values. The previous version
            // asserted areEquals(dot, dot), which can never fail.

            // Orthogonal unit vectors: zero, in either argument order.
            float dot = dotProduct(vec1, vec2);
            assert(areEquals(dot, 0.0f) == true);

            dot = dotProduct(vec2, vec1);
            assert(areEquals(dot, 0.0f) == true);

            // A unit vector with itself: its squared length, 1.
            dot = dotProduct(vec1, vec1);
            assert(areEquals(dot, 1.0f) == true);

            // Opposite unit vectors: -1.
            dot = dotProduct(vec1, vec3);
            assert(areEquals(dot, -1.0f) == true);
        }

        // crossProduct(): X cross Z = -Y, Z cross X = +Y, and a vector
        // crossed with itself is zero.
        {
            const float x1 = 1.0f;
            const float y1 = 0.0f;
            const float z1 = 0.0f;
            const Vector3 vec1(x1, y1, z1);

            const float x2 = 0.0f;
            const float y2 = 0.0f;
            const float z2 = 1.0f;
            const Vector3 vec2(x2, y2, z2);

            const float x3 = 0.0f;
            const float y3 = -1.0f;
            const float z3 = 0.0f;
            Vector3 vec3(x3, y3, z3);

            const float x4 = 0.0f;
            const float y4 = 0.0f;
            const float z4 = 0.0f;
            const Vector3 vec4(x4, y4, z4);

            Vector3 vec;
            crossProduct(vec1, vec2, vec);
            assert(vec == vec3);

            // Swapping the operands negates the result: expect +Y now.
            vec3 *= -1.0f;
            crossProduct(vec2, vec1, vec);
            assert(vec == vec3);

            crossProduct(vec1, vec1, vec);
            assert(vec == vec4);
        }

        // normalize(): the result has unit length. The same vector is
        // passed as input and output, so this also covers in-place use.
        {
            const float x1 = 129.0f;
            const float y1 = -14.0f;
            const float z1 = 0.0f;
            Vector3 vec1(x1, y1, z1);

            normalize(vec1, vec1);
            const float len = vec1.length();
            assert(areEquals(len, 1.0f) == true);
        }

        // Vector3::isZero()
        {
            const float x1 = 0.0f;
            const float y1 = 0.0f;
            const float z1 = 0.0f;
            const Vector3 vec1(x1, y1, z1);

            const float x2 = 10.0f;
            const float y2 = 0.0f;
            const float z2 = -10.0f;
            const Vector3 vec2(x2, y2, z2);

            assert(vec1.isZero() == true);
            assert(vec2.isZero() == false);
        }

        // Vector3::length() along each axis.
        {
            const float len = 15.0f;

            const float x1 = len;
            const float y1 = 0.0f;
            const float z1 = 0.0f;
            const Vector3 vec1(x1, y1, z1);

            const float x2 = 0.0f;
            const float y2 = len;
            const float z2 = 0.0f;
            const Vector3 vec2(x2, y2, z2);

            const float x3 = 0.0f;
            const float y3 = 0.0f;
            const float z3 = len;
            const Vector3 vec3(x3, y3, z3);

            float newLen = vec1.length();
            assert(areEquals(len, newLen) == true);

            newLen = vec2.length();
            assert(areEquals(len, newLen) == true);

            newLen = vec3.length();
            assert(areEquals(len, newLen) == true);
        }

        // Vector3::setLength(): rescales a vector whose length is not
        // already the target value.
        {
            const float len = 121.0f;

            const float x1 = 129.0f;
            const float y1 = -14.0f;
            const float z1 = 0.0f;
            Vector3 vec1(x1, y1, z1);

            // Precondition: the starting length differs from the target.
            float newLen = vec1.length();
            assert(areEquals(len, newLen) == false);

            vec1.setLength(len);

            newLen = vec1.length();
            assert(areEquals(len, newLen) == true);
        }

        // Vector3::sqrLength() along each axis.
        {
            const float len = 15.0f;
            const float sqrLen = len * len;

            const float x1 = len;
            const float y1 = 0.0f;
            const float z1 = 0.0f;
            const Vector3 vec1(x1, y1, z1);

            const float x2 = 0.0f;
            const float y2 = len;
            const float z2 = 0.0f;
            const Vector3 vec2(x2, y2, z2);

            const float x3 = 0.0f;
            const float y3 = 0.0f;
            const float z3 = len;
            const Vector3 vec3(x3, y3, z3);

            float length = vec1.sqrLength();
            assert(areEquals(sqrLen, length) == true);

            length = vec2.sqrLength();
            assert(areEquals(sqrLen, length) == true);

            length = vec3.sqrLength();
            assert(areEquals(sqrLen, length) == true);
        }
    }
}

namespace mathUnitTests {
    // Entry point for the math unit tests. It runs each test group in a
    // fixed order: utilities, Point2, Point3, then Vector3. A failing check
    // aborts via assert() unless NDEBUG is defined.
    void runTests() {
        typedef void (*TestGroup)();

        const TestGroup groups[] = {
            &utilsTests,
            &point2Tests,
            &point3Tests,
            &vector3Tests,
        };

        for (const TestGroup group : groups) {
            group();
        }
    }
}